Unsupervised learning

Clustering with GMM and the EM algorithm

#install.packages("mvtnorm")
library(mvtnorm)
# Fit a K-component Gaussian mixture model (full covariances) with EM.
#
# Arguments:
#   X       numeric matrix, one observation per row (coerced with as.matrix()).
#   K       number of mixture components (>= 1).
#   n.iter  number of EM iterations; the loop has no convergence test.
#   plot    if TRUE, redraw the current clustering after every E step.
#   ...     currently unused; accepted so existing calls keep working.
#
# Value: a list with
#   P       n x K posterior membership probabilities from the last E step,
#   prop    mixing proportions (length K),
#   mu      K x p matrix of component means,
#   Sigma   K x p x p array of component covariances,
#   loglik  observed-data log-likelihood at the returned parameters,
#   bic     loglik - nu / 2 * log(n), where nu is the number of free
#           parameters; with this sign convention larger is better.
#
# Side effects: prints one '.' per iteration and always draws the final
# clustering on the active graphics device.
EMalgo <- function(X, K, n.iter = 50, plot = FALSE, ...) {
  X <- as.matrix(X)
  stopifnot(K >= 1, n.iter >= 1)
  n <- nrow(X)
  p <- ncol(X)

  # Colour of component k is k + 1, matching the point colours max.col(P) + 1.
  comp_col <- seq_len(K) + 1

  # S[k, , ] drops to a bare scalar when p == 1; restore the p x p shape.
  cov_of <- function(S, k) matrix(S[k, , ], p, p)

  # n x K matrix whose (i, k) entry is prop[k] * N(x_i | mu_k, Sigma_k).
  mix_dens <- function(prop, mu, Sigma) {
    D <- matrix(NA_real_, n, K)
    for (k in seq_len(K)) {
      D[, k] <- prop[k] * dmvnorm(X, mu[k, ], cov_of(Sigma, k))
    }
    D
  }

  # Scatter plot coloured by most probable component, with the centres on top.
  draw <- function(P, mu) {
    plot(X, col = max.col(P) + 1)
    points(mu, col = comp_col, pch = 19, cex = 3)
  }

  # Initialize: equal proportions, random means around the overall centroid,
  # identity covariances.
  prop <- rep(1 / K, K)
  mu <- rmvnorm(K, mean = colMeans(X))
  Sigma <- array(0, dim = c(K, p, p))
  for (k in seq_len(K)) Sigma[k, , ] <- diag(p)

  # The EM loop
  for (it in seq_len(n.iter)) {
    cat('.')

    # E step: posterior probabilities, each row normalized to sum to 1.
    P <- mix_dens(prop, mu, Sigma)
    P <- P / rowSums(P)

    # Optional plot
    if (plot) {
      draw(P, mu)
      Sys.sleep(0.1)
    }

    # M step: weighted proportion, mean and covariance of each component.
    for (k in seq_len(K)) {
      w <- P[, k]
      nk <- sum(w)
      prop[k] <- nk / n
      mu[k, ] <- colSums(w * X) / nk
      Xc <- sweep(X, 2, mu[k, ])
      Sigma[k, , ] <- crossprod(w * Xc, Xc) / nk
    }
  }
  cat('\n')
  draw(P, mu)

  # Observed-data log-likelihood: sum_i log( sum_k prop_k N(x_i | mu_k, Sigma_k) ).
  loglik <- sum(log(rowSums(mix_dens(prop, mu, Sigma))))

  # Free parameters: K - 1 proportions, K * p means and
  # K * p * (p + 1) / 2 entries of the symmetric covariances.
  n_par <- (K - 1) + K * p + K * p * (p + 1) / 2
  bic <- loglik - n_par / 2 * log(n)

  list(P = P, prop = prop, mu = mu, Sigma = Sigma,
       bic = bic, loglik = loglik)
}

Let’s start with an easy situation:

# Simulate two low-variance spherical Gaussian clusters:
# 100 points around (0, 0) and 200 points around (-2, 2).
X <- rbind(
  rmvnorm(100, mean = c(0, 0), sigma = 0.1 * diag(2)),
  rmvnorm(200, mean = c(-2, 2), sigma = 0.2 * diag(2))
)
plot(X)

# Fit a two-component mixture with the default number of EM iterations.
out <- EMalgo(X, K = 2)
## ..................................................

# Estimated component means, one row per component.
out[["mu"]]
##              [,1]        [,2]
## [1,] -1.993723983  2.03243307
## [2,] -0.001575338 -0.00781512
# Overlay the estimated means on the raw data.
plot(X)
points(out[["mu"]], col = 2:3, pch = 19, cex = 3)

Here, we can see the effect of each iteration on the inference of the means and the estimation of cluster memberships:

# Re-run with only 10 iterations, redrawing the clustering after each E step.
out <- EMalgo(X, K = 2, n.iter = 10, plot = TRUE)
## .

## .

## .

## .

## .

## .

## .

## .

## .

## .

Let’s now consider a more difficult problem:

# Same two centres as before, but with a larger common variance (0.8),
# so the clusters overlap.
X <- rbind(
  rmvnorm(100, mean = c(0, 0), sigma = 0.8 * diag(2)),
  rmvnorm(200, mean = c(-2, 2), sigma = 0.8 * diag(2))
)
plot(X)

# Fit a two-component mixture with the default number of EM iterations.
out <- EMalgo(X, K = 2)
## ..................................................

# Estimated component means, one row per component.
out[["mu"]]
##            [,1]       [,2]
## [1,] -1.8202362 1.89251405
## [2,]  0.2863645 0.01007436

Let’s now consider the problem of choosing \(K\):

# Two clusters with an intermediate common variance (0.4), used below to
# compare fits for several values of K.
X <- rbind(
  rmvnorm(100, mean = c(0, 0), sigma = 0.4 * diag(2)),
  rmvnorm(200, mean = c(-2, 2), sigma = 0.4 * diag(2))
)
plot(X)

# Draw the fitted clustering for K = 1, ..., 6 in a 2 x 3 grid of panels.
# par() returns the previous settings; they are restored afterwards so that
# later plots in the same session are not drawn into this grid.
old_par <- par(mfrow = c(2, 3))
for (k in 1:6) EMalgo(X, k, n.iter = 50)
par(old_par)
## ..................................................
## ..................................................
## ..................................................
## ..................................................
## ..................................................
## ..................................................

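Let’s now use BIC to select the best \(K\). With the convention used in `EMalgo` (log-likelihood minus a complexity penalty), the best \(K\) is the one with the largest BIC: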

# BIC of the fitted mixture for each K = 1, ..., 6 (one EM run per K).
bic <- rep(NA, 6)
for (k in seq_along(bic)) {
  bic[k] <- EMalgo(X, k)[["bic"]]
}
## ..................................................

## ..................................................

## ..................................................

## ..................................................

## ..................................................

## ..................................................

# BIC against K, drawn as points joined by lines.
plot(bic, type = "b")